Skip to content

Add model metadata #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Jun 4, 2025
Merged

Add model metadata #135

merged 32 commits into from
Jun 4, 2025

Conversation

willdumm
Copy link
Contributor

@willdumm willdumm commented May 2, 2025

This PR adds the following values to metadata of saved models:

  • multihit_model_name: expected to be a key in netam.pretrained.PRETRAINED_MULTIHIT_MODELS. Defaults to netam.models.DEFAULT_MULTIHIT_MODEL. For crepes saved without this data, defaults to None.
  • neutral_model_name: expected to be a named pretrained neutral model. Defaults to netam.models.DEFAULT_NEUTRAL_MODEL. For crepes saved without this data, defaults to ThriftyHumV0.2-59.
  • train_timestamp: a UTC timestamp taken at the time of model initialization, if not provided explicitly (e.g. 2025-05-01T22:05). For crepes saved without this data, defaults to old
  • model_type: either dnsm, dasm, or ddsm which must be provided at the time of model instantiation. For crepes saved without this data, defaults to unknown, and will throw warnings.

As hinted at above, I added a dictionary containing pretrained multihit models to netam.pretrained. These models can be accessed by name using netam.pretrained.load_multihit.

Requires companion PR https://github.com/matsengrp/dnsm-experiments-1/pull/132

@willdumm willdumm requested a review from Copilot May 2, 2025 18:55
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR extends model metadata to include multihit and neutral model settings and integrates these changes across tests and core model functions.

  • Updates tests to load and use multihit models via load_multihit.
  • Extends AbstractBinarySelectionModel and SingleValueBinarySelectionModel with new metadata (including model_type, train_timestamp, neutral_model_name, and multihit_model_name) and adjusts hyperparameter defaults.
  • Enhances framework functions (including add_shm_model_outputs_to_pcp_df and DXSMBurrito initialization) to verify model metadata consistency.

Reviewed Changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
tests/test_simulation.py Uses load_multihit to retrieve multihit model and adds tolerance in allclose check; reassigns train_dataset to val_dataset.
tests/test_multihit.py Updates model instantiation to pass model_type and generate multihit_model_name from model weights.
tests/test_dnsm.py, test_ddsm.py, test_dasm.py, test_ambiguous.py Integrates new parameter model_type and multihit_model into model/dataset creation.
netam/pretrained.py Introduces load_multihit and name_and_multihit_model_match for multihit model handling.
netam/models.py Extends metadata in model constructors and updates reinitialize_weights, to_weights, and from_weights methods.
netam/framework.py Adds default hyperparameter values for legacy models and filters sequences in add_shm_model_outputs_to_pcp_df.
netam/dxsm.py Implements metadata validation with warnings regarding model_type and multihit model consistency.

@willdumm willdumm requested a review from matsen May 2, 2025 19:15
Copy link
Contributor

@matsen matsen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a few final todos 👍

@@ -66,13 +63,7 @@ def apply_multihit_correction(
per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs)
corrections = torch.cat([torch.tensor([0.0]), log_hit_class_factors]).exp()
reshaped_corrections = corrections[per_parent_hit_class]
unnormalized_corrected_probs = clamp_probability(codon_probs * reshaped_corrections)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a refactor -- the forward method of the multihit model still sets the parent codon probability, but this allows the model to expose a method that adjusts codon probs but does not set the parent codon probability.

@willdumm
Copy link
Contributor Author

Leaving this here just to document that I did try implementing simulation probabilities with build_codon_mutsel. Here's the working function (it agrees with the version I am using exactly)

def codon_probs_of_parent_seq_new(
    selection_crepe, nt_sequence, branch_length, neutral_crepe=None, multihit_model=None
):
    """Calculate the predicted model probabilities of each codon at each site.

    Args:
        nt_sequence: A tuple of two strings, the heavy and light chain nucleotide
            sequences.
        branch_length: The branch length of the tree.
    Returns:
        a tuple of tensors of shape (L, 64) representing the predicted probabilities of each
        codon at each site.
    """
    if neutral_crepe is None:
        raise NotImplementedError("neutral_crepe is required.")

    if isinstance(nt_sequence, str) or len(nt_sequence) != 2:
        raise ValueError(
            "nt_sequence must be a pair of strings, with the first element being the heavy chain sequence and the second element being the light chain sequence."
        )

    aa_seqs = tuple(translate_sequences_mask_codons(nt_sequence))
    # We must mask any codons containing N's because we need neutral probs to
    # do simulation:
    mask = tuple(codon_mask_tensor_of(chain_nt_seq) for chain_nt_seq in nt_sequence)
    rates, csps = trimmed_shm_outputs_of_parent_pair(neutral_crepe, nt_sequence)

    selection_factors = selection_crepe([aa_seqs])[0]

    if selection_crepe.model.hyperparameters["output_dim"] == 1:
        # Need to upgrade single selection factor to 20 selection factors, all
        # equal except for the one for the parent sequence, which should be
        # 1 (0 in log space).
        new_selection_factors = []
        for aa_seq, old_selection_factors in zip(aa_seqs, selection_factors):
            if len(aa_seq) == 0:
                new_selection_factors.append(torch.empty(0, 20, dtype=old_selection_factors.dtype))
            else:
                parent_indices = aa_idx_tensor_of_str_ambig(aa_seq)
                # print(old_selection_factors)
                new_selection_factors.append(
                    # Selection factors are expected to be in linear space here
                    molevol.lift_to_per_aa_selection_factors(old_selection_factors, parent_indices)
                )
        selection_factors = tuple(new_selection_factors)

    parent_nt_idxs = tuple(
        nt_idx_tensor_of_str(nt_chain_seq.replace("N", "A")) for nt_chain_seq in nt_sequence
    )
    codon_probs = []
    for parent_idxs, nt_csps, nt_rates, sel_matrix in zip(parent_nt_idxs, csps, rates, selection_factors):
        if len(parent_idxs) > 0:
            nt_mut_probs = 1.0 - torch.exp(-branch_length * nt_rates)
            codon_mutsel, _ = molevol.build_codon_mutsel(
                parent_idxs.reshape(-1, 3),
                nt_mut_probs.reshape(-1, 3),
                nt_csps.reshape(-1, 3, 4),
                sel_matrix,
                multihit_model=multihit_model,
            )
            codon_probs.append(molevol.zero_stop_codon_probs(molevol.flatten_codons(clamp_probability(codon_mutsel))))
        else:
            codon_probs.append(torch.empty(0, 64, dtype=torch.float32))

    return tuple(codon_probs)

@willdumm willdumm requested a review from Copilot May 30, 2025 22:57
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces additional metadata fields to saved models and updates tests, fixtures, and model constructors accordingly. Key changes include:

  • Adding metadata keys (multihit_model_name, neutral_model_name, train_timestamp, model_type) to model initialization and hyperparameters.
  • Updating tests and fixtures across multiple files to accommodate the new metadata.
  • Enhancing pretrained model loading with a new multihit models dictionary and associated utility functions.

Reviewed Changes

Copilot reviewed 26 out of 26 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
tests/* Update fixtures and test references for new metadata and defaults
netam/models.py Extend model constructors and hyperparameters with metadata
netam/pretrained.py Add PRETRAINED_MULTIHIT_MODELS dict and load_multihit function
netam/molevol.py, netam/hit_class.py Update processing of mutation probabilities with multihit support
netam/framework.py, others Various adjustments to integrate metadata into the workflow
Comments suppressed due to low confidence (1)

netam/models.py:580

  • Consider requiring 'model_type' as a mandatory argument instead of defaulting to None and issuing a warning, to enforce explicit model typing and simplify downstream logic.
def __init__(self, output_dim: int = 1, known_token_count: int = MAX_AA_TOKEN_IDX + 1, neutral_model_name: str = DEFAULT_NEUTRAL_MODEL, multihit_model_name: str = DEFAULT_MULTIHIT_MODEL, train_timestamp: str = None, model_type: str = None):

Copy link
Contributor

@matsen matsen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing! Some very minor things here.

netam/common.py Outdated
@@ -67,6 +67,10 @@ def clamp_probability(x: Tensor) -> Tensor:
return torch.clamp(x, min=SMALL_PROB, max=(1.0 - SMALL_PROB))


def clamp_probability_above(x: Tensor) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make this function name concordant with the function below. They are both above, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to clamp_probability_above_only

DEFAULT_NEUTRAL_MODEL = "ThriftyHumV0.2-59"
DEFAULT_MULTIHIT_MODEL = None
# # ATTENTION!!! when done with dnsm retrainings, switch back to this:
# DEFAULT_MULTIHIT_MODEL = "ThriftyHumV0.2-59-hc-tangshm"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a calendar event or something to remind us? 😁

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just created one for Friday!

result = self.apply_multihit_correction(
parent_codon_idxs, uncorrected_codon_probs
)
# clamp only above to avoid summing a bunch of small fake values when
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we pull this out into a function with a name and a slightly clearer docstring? This is a little on the opaque side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added more comments instead

netam/models.py Outdated
):
"""Apply the correction to the uncorrected codon probabilities.

Unlike `forward` this does
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

line wrapping

@@ -320,7 +324,7 @@ def build_codon_mutsel(
codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

if multihit_model is not None:
codon_probs = multihit_model(parent_codon_idxs, codon_probs)
codon_probs = multihit_model.forward(parent_codon_idxs, codon_probs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have thought these were identical.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are, but I like the style of an explicit named method call for code searching purposes

@@ -383,7 +387,7 @@ def neutral_codon_probs(
codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

if multihit_model is not None:
codon_probs = multihit_model(parent_codon_idxs, codon_probs)
codon_probs = multihit_model.forward(parent_codon_idxs, codon_probs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

print(flat_log_codon_mutsel[diff_mask])
assert False

# adjusted_codon_probs = molevol.zero_stop_codon_probs(clamp_probability(adjusted_codon_probs.exp()).log())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not taken a close read of these functions-- I trust that they are doing what you want-- but perhaps you want to make a quick scan to tidy things up. Is this useful or cruft?

@willdumm willdumm merged commit 264877d into main Jun 4, 2025
2 checks passed
@willdumm willdumm deleted the 122-metadata branch June 4, 2025 20:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants